{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6bYaCABobL5q"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FlUw7tSKbtg4"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_-fogOi3K7nR"
      },
      "source": [
        "# Use TF1.x models in TF2 workflows\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/model_mapping\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/model_mapping.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/model_mapping.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/model_mapping.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7-GwECUqrkqT"
      },
      "source": [
        "This guide provides an overview and examples of a [modeling code shim](https://en.wikipedia.org/wiki/Shim_(computing)) that you can employ to use your existing TF1.x models in TF2 workflows such as eager execution, `tf.function`, and distribution strategies with minimal changes to your modeling code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k_ezCbogxaqt"
      },
      "source": [
        "## Scope of usage\n",
        "\n",
        "The shim described in this guide is designed for TF1.x models that rely on:\n",
        "1. `tf.compat.v1.get_variable` and `tf.compat.v1.variable_scope` to control variable creation and reuse, and\n",
        "1. Graph-collection based APIs such as `tf.compat.v1.global_variables()`, `tf.compat.v1.trainable_variables`, `tf.compat.v1.losses.get_regularization_losses()`, and `tf.compat.v1.get_collection()` to keep track of weights and regularization losses\n",
        "\n",
        "This includes most models built on top of `tf.compat.v1.layer`, `tf.contrib.layers` APIs, and [TensorFlow-Slim](https://github.com/google-research/tf-slim).\n",
        "\n",
        "The shim is **NOT** necessary for the following TF1.x models:\n",
        "\n",
        "1. Stand-alone Keras models that already track all of their trainable weights and regularization losses via `model.trainable_weights` and `model.losses` respectively.\n",
        "1. `tf.Module`s that already track all of their trainable weights via `module.trainable_variables`, and only create weights if they have not already been created.\n",
        "\n",
        "These models are likely to work in TF2 with eager execution and `tf.function`s out-of-the-box."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3OQNFp8zgV0C"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Import TensorFlow and other dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EG2n3-qlD5mA"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y -q tensorflow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mVfR3MBvD9Sc"
      },
      "outputs": [],
      "source": [
        "# Install tf-nightly as the DeterministicRandomTestTool is available only in\n",
        "# Tensorflow 2.8\n",
        "\n",
        "!pip install -q tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PzkV-2cna823"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as v1\n",
        "import sys\n",
        "import numpy as np\n",
        "\n",
        "from contextlib import contextmanager"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ox4kn0DK8H0f"
      },
      "source": [
        "## The `track_tf1_style_variables` decorator\n",
        "\n",
        "The key shim described in this guide is `tf.compat.v1.keras.utils.track_tf1_style_variables`, a decorator that you can use within methods belonging to `tf.keras.layers.Layer` and `tf.Module` to track TF1.x-style weights and capture regularization losses.\n",
        "\n",
        "Decorating a `tf.keras.layers.Layer`'s  or `tf.Module`'s call methods with `tf.compat.v1.keras.utils.track_tf1_style_variables` allows variable creation and reuse via `tf.compat.v1.get_variable` (and by extension `tf.compat.v1.layers`) to work correctly inside of the decorated method rather than always creating a new variable on each call. It will also cause the layer or module to implicitly track any weights created or accessed via `get_variable` inside the decorated method.\n",
        "\n",
        "In addition to tracking the weights themselves under the standard\n",
        "`layer.variable`/`module.variable`/etc. properties, if the method belongs\n",
        "to a `tf.keras.layers.Layer`, then any regularization losses specified via the\n",
        "`get_variable` or `tf.compat.v1.layers` regularizer arguments will get\n",
        "tracked by the layer under the standard `layer.losses` property.\n",
        "\n",
        "This tracking mechanism enables using large classes of TF1.x-style model-forward-pass code inside of Keras layers or `tf.Module`s in TF2 even with TF2 behaviors enabled.\n"
      ]
    },
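    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the tracking behavior described above, the following applies the decorator to a `tf.Module` method. The `DenseModule` class, its `units` argument, and the `\"dense\"` scope name are hypothetical and chosen only for illustration. The sketch assumes a TensorFlow installation that ships `tf.compat.v1.keras.utils.track_tf1_style_variables`.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "class DenseModule(tf.Module):\n",
        "  # Hypothetical module whose weights are created via `get_variable`.\n",
        "\n",
        "  def __init__(self, units, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.units = units\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def __call__(self, inputs):\n",
        "    with tf.compat.v1.variable_scope(\"dense\"):\n",
        "      kernel = tf.compat.v1.get_variable(\n",
        "          name=\"kernel\",\n",
        "          shape=[inputs.shape[-1], self.units],\n",
        "          initializer=tf.compat.v1.initializers.glorot_normal())\n",
        "      bias = tf.compat.v1.get_variable(\n",
        "          name=\"bias\",\n",
        "          shape=[self.units],\n",
        "          initializer=tf.compat.v1.initializers.zeros())\n",
        "      return tf.linalg.matmul(inputs, kernel) + bias\n",
        "\n",
        "\n",
        "module = DenseModule(4)\n",
        "x = tf.ones(shape=(2, 3))\n",
        "first = module(x)   # Creates `kernel` and `bias`.\n",
        "second = module(x)  # Requests the same names, so the variables are reused.\n",
        "\n",
        "print([v.name for v in module.trainable_variables])\n",
        "```\n",
        "\n",
        "If the shim behaves as described, the module should report exactly two trainable variables after both calls, and both calls should return the same values."
      ]
    },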
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sq6IqZILmGmO"
      },
      "source": [
        "## Usage examples\n",
        "\n",
        "The usage examples below demonstrate the modeling shims used to decorate `tf.keras.layers.Layer` methods, but except where they are specifically interacting with Keras features they are applicable when decorating `tf.Module` methods as well."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YWGPh6KmkHq6"
      },
      "source": [
        "### Layer built with tf.compat.v1.get_variable\n",
        "\n",
        "Imagine you have a layer implemented directly on top of `tf.compat.v1.get_variable` as follows:\n",
        "\n",
        "```python\n",
        "def dense(self, inputs, units):\n",
        "  out = inputs\n",
        "  with tf.compat.v1.variable_scope(\"dense\"):\n",
        "    # The weights are created with a `regularizer`,\n",
        "    kernel = tf.compat.v1.get_variable(\n",
        "        shape=[out.shape[-1], units],\n",
        "        regularizer=tf.keras.regularizers.L2(),\n",
        "        initializer=tf.compat.v1.initializers.glorot_normal,\n",
        "        name=\"kernel\")\n",
        "    bias = tf.compat.v1.get_variable(\n",
        "        shape=[units,],\n",
        "        initializer=tf.compat.v1.initializers.zeros,\n",
        "        name=\"bias\")\n",
        "    out = tf.linalg.matmul(out, kernel)\n",
        "    out = tf.compat.v1.nn.bias_add(out, bias)\n",
        "  return out\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6sZWU7JSok2n"
      },
      "source": [
        "Use the shim to turn it into a layer and call it on inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q3eKkcKtS_N4"
      },
      "outputs": [],
      "source": [
        "class DenseLayer(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    out = inputs\n",
        "    with tf.compat.v1.variable_scope(\"dense\"):\n",
        "      # The weights are created with a `regularizer`,\n",
        "      # so the layer should track their regularization losses\n",
        "      kernel = tf.compat.v1.get_variable(\n",
        "          shape=[out.shape[-1], self.units],\n",
        "          regularizer=tf.keras.regularizers.L2(),\n",
        "          initializer=tf.compat.v1.initializers.glorot_normal,\n",
        "          name=\"kernel\")\n",
        "      bias = tf.compat.v1.get_variable(\n",
        "          shape=[self.units,],\n",
        "          initializer=tf.compat.v1.initializers.zeros,\n",
        "          name=\"bias\")\n",
        "      out = tf.linalg.matmul(out, kernel)\n",
        "      out = tf.compat.v1.nn.bias_add(out, bias)\n",
        "    return out\n",
        "\n",
        "layer = DenseLayer(10)\n",
        "x = tf.random.normal(shape=(8, 20))\n",
        "layer(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JqXAlWnYgwcq"
      },
      "source": [
        "Access the tracked variables and the captured regularization losses like a standard Keras layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZNz5HmkXg0B5"
      },
      "outputs": [],
      "source": [
        "layer.trainable_variables\n",
        "layer.losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W0z9GmRlhM9X"
      },
      "source": [
        "To see that the weights get reused each time you call the layer, set all the weights to zero and call the layer again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZJ4vOu2Rf-I2"
      },
      "outputs": [],
      "source": [
        "print(\"Resetting variables to zero:\", [var.name for var in layer.trainable_variables])\n",
        "\n",
        "for var in layer.trainable_variables:\n",
        "  var.assign(var * 0.0)\n",
        "\n",
        "# Note: layer.losses is not a live view and\n",
        "# will get reset only at each layer call\n",
        "print(\"layer.losses:\", layer.losses)\n",
        "print(\"calling layer again.\")\n",
        "out = layer(x)\n",
        "print(\"layer.losses: \", layer.losses)\n",
        "out"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WwEprtA-lOh6"
      },
      "source": [
        "You can use the converted layer directly in Keras functional model construction as well."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7E7ZCINHlaHU"
      },
      "outputs": [],
      "source": [
        "inputs = tf.keras.Input(shape=(20))\n",
        "outputs = DenseLayer(10)(inputs)\n",
        "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
        "\n",
        "x = tf.random.normal(shape=(8, 20))\n",
        "model(x)\n",
        "\n",
        "# Access the model variables and regularization losses\n",
        "model.weights\n",
        "model.losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ew5TTEyZkZGU"
      },
      "source": [
        "### Model built with `tf.compat.v1.layers`\n",
        "\n",
        "Imagine you have a layer or model implemented directly on top of `tf.compat.v1.layers` as follows:\n",
        "\n",
        "```python\n",
        "def model(self, inputs, units):\n",
        "  with tf.compat.v1.variable_scope('model'):\n",
        "    out = tf.compat.v1.layers.conv2d(\n",
        "        inputs, 3, 3,\n",
        "        kernel_regularizer=\"l2\")\n",
        "    out = tf.compat.v1.layers.flatten(out)\n",
        "    out = tf.compat.v1.layers.dense(\n",
        "        out, units,\n",
        "        kernel_regularizer=\"l2\")\n",
        "    return out\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gZolXllfpVx6"
      },
      "source": [
        "Use the shim to turn it into a layer and call it on inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cBpfSHWTTTCv"
      },
      "outputs": [],
      "source": [
        "class CompatV1LayerModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          inputs, 3, 3,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      out = tf.compat.v1.layers.flatten(out)\n",
        "      out = tf.compat.v1.layers.dense(\n",
        "          out, self.units,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      return out\n",
        "\n",
        "layer = CompatV1LayerModel(10)\n",
        "x = tf.random.normal(shape=(8, 5, 5, 5))\n",
        "layer(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OkG9oLlblfK_"
      },
      "source": [
        "Warning: For safety reasons, make sure to put all `tf.compat.v1.layers` inside of a non-empty-string `variable_scope`. This is because `tf.compat.v1.layers` with auto-generated names will always auto-increment the name outside of any variable scope. This means the requested variable names will mismatch each time you call the layer/module. So, rather than reusing the already-made weights it will create a new set of variables every call."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zAVN6dy3p7ik"
      },
      "source": [
        "Access the tracked variables and captured regularization losses like a standard Keras layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HTRF99vJp7ik"
      },
      "outputs": [],
      "source": [
        "layer.trainable_variables\n",
        "layer.losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kkNuEcyIp7ik"
      },
      "source": [
        "To see that the weights get reused each time you call the layer, set all the weights to zero and call the layer again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4dk4XScdp7il"
      },
      "outputs": [],
      "source": [
        "print(\"Resetting variables to zero:\", [var.name for var in layer.trainable_variables])\n",
        "\n",
        "for var in layer.trainable_variables:\n",
        "  var.assign(var * 0.0)\n",
        "\n",
        "out = layer(x)\n",
        "print(\"layer.losses: \", layer.losses)\n",
        "out"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7zD3a8PKzU7S"
      },
      "source": [
        "You can use the converted layer directly in Keras functional model construction as well."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q88BgBCup7il"
      },
      "outputs": [],
      "source": [
        "inputs = tf.keras.Input(shape=(5, 5, 5))\n",
        "outputs = CompatV1LayerModel(10)(inputs)\n",
        "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
        "\n",
        "x = tf.random.normal(shape=(8, 5, 5, 5))\n",
        "model(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2cioB6Zap7il"
      },
      "outputs": [],
      "source": [
        "# Access the model variables and regularization losses\n",
        "model.weights\n",
        "model.losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NBNODOx9ly6r"
      },
      "source": [
        "### Capture batch normalization updates and model `training` args\n",
        "\n",
        "In TF1.x, you perform batch normalization like this:\n",
        "\n",
        "```python\n",
        "  x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)\n",
        "\n",
        "  # ...\n",
        "\n",
        "  update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
        "  train_op = optimizer.minimize(loss)\n",
        "  train_op = tf.group([train_op, update_ops])\n",
        "```\n",
        "Note that:\n",
        "1. The batch normalization moving average updates are tracked by `get_collection` which was called separately from the layer\n",
        "2. `tf.compat.v1.layers.batch_normalization` requires a `training` argument (generally called `is_training` when using TF-Slim batch normalization layers)\n",
        "\n",
        "In TF2, due to [eager execution](https://www.tensorflow.org/guide/eager) and automatic control dependencies, the batch normalization moving average updates will be executed right away. There is no need to separately collect them from the updates collection and add them as explicit control dependencies.\n",
        "\n",
        "Additionally, if you give your `tf.keras.layers.Layer`'s forward pass method a `training` argument, Keras will be able to pass the current training phase and any nested layers to it just like it does for any other layer. See the API docs for `tf.keras.Model` for more information on how Keras handles the `training` argument.\n",
        "\n",
        "If you are decorating `tf.Module` methods, you need to make sure to manually pass all `training` arguments as needed. However, the batch normalization moving average updates will still be applied automatically with no need for explicit control dependencies.\n",
        "\n",
        "The following code snippets demonstrate how to embed batch normalization layers in the shim and how using it in a Keras model works (applicable to `tf.keras.layers.Layer`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CjZE-J7mkS9p"
      },
      "outputs": [],
      "source": [
        "class CompatV1BatchNorm(tf.keras.layers.Layer):\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs, training=None):\n",
        "    print(\"Forward pass called with `training` =\", training)\n",
        "    with v1.variable_scope('batch_norm_layer'):\n",
        "      return v1.layers.batch_normalization(x, training=training)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NGuvvElmY-fu"
      },
      "outputs": [],
      "source": [
        "print(\"Constructing model\")\n",
        "inputs = tf.keras.Input(shape=(5, 5, 5))\n",
        "outputs = CompatV1BatchNorm()(inputs)\n",
        "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
        "\n",
        "print(\"Calling model in inference mode\")\n",
        "x = tf.random.normal(shape=(8, 5, 5, 5))\n",
        "model(x, training=False)\n",
        "\n",
        "print(\"Moving average variables before training: \",\n",
        "      {var.name: var.read_value() for var in model.non_trainable_variables})\n",
        "\n",
        "# Notice that when running TF2 and eager execution, the batchnorm layer directly\n",
        "# updates the moving averages while training without needing any extra control\n",
        "# dependencies\n",
        "print(\"calling model in training mode\")\n",
        "model(x, training=True)\n",
        "\n",
        "print(\"Moving average variables after training: \",\n",
        "      {var.name: var.read_value() for var in model.non_trainable_variables})\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gai4ikpmeRqR"
      },
      "source": [
        "### Variable-scope based variable reuse\n",
        "Any variable creations in the forward pass based on `get_variable` will maintain the same variable naming and reuse semantics that variable scopes have in TF1.x. This is true as long as you have at least one non-empty outer scope for any `tf.compat.v1.layers` with auto-generated names, as mentioned above.\n",
        "\n",
        "Note: Naming and reuse will be scoped to within a single layer/module instance. Calls to `get_variable` inside one shim-decorated layer or module will not be able to refer to variables created inside of layers or modules. You can get around this by using Python references to other variables directly if need be, rather than accessing variables via `get_variable`."
      ]
    },
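    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The per-instance scoping in the note above can be sketched as follows. The `ScaleLayer` class and its `\"scale\"`/`\"w\"` names are hypothetical and exist only for this illustration. The sketch assumes a TensorFlow installation that ships `tf.compat.v1.keras.utils.track_tf1_style_variables`.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "class ScaleLayer(tf.keras.layers.Layer):\n",
        "  # Hypothetical layer with a single `get_variable`-created weight.\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    with tf.compat.v1.variable_scope(\"scale\"):\n",
        "      w = tf.compat.v1.get_variable(\n",
        "          name=\"w\",\n",
        "          shape=[inputs.shape[-1]],\n",
        "          initializer=tf.compat.v1.initializers.ones())\n",
        "    return inputs * w\n",
        "\n",
        "\n",
        "x = tf.ones(shape=(2, 3))\n",
        "layer_a = ScaleLayer()\n",
        "layer_b = ScaleLayer()\n",
        "layer_a(x)\n",
        "layer_b(x)\n",
        "\n",
        "# Both layers request the same scoped name, but each resolves it in its\n",
        "# own variable store, so the two weights are separate objects.\n",
        "layer_a.weights[0].assign([2.0, 2.0, 2.0])\n",
        "out_a = layer_a(x)  # Uses the updated weight of `layer_a`.\n",
        "out_b = layer_b(x)  # Uses the untouched weight of `layer_b`.\n",
        "```\n",
        "\n",
        "To share a weight between instances, pass a Python reference to the variable instead of relying on matching `get_variable` names."
      ]
    },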
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6PzYZdX2nMVt"
      },
      "source": [
        "### Eager execution & `tf.function`\n",
        "\n",
        "As seen above, decorated methods for `tf.keras.layers.Layer` and `tf.Module` run inside of eager execution and are also compatible with `tf.function`. This means you can use [pdb](https://docs.python.org/3/library/pdb.html) and other interactive tools to step through your forward pass as it is running.\n",
        "\n",
        "Warning: Although it is perfectly safe to call your shim-decorated layer/module methods from *inside* of a `tf.function`, it is not safe to put `tf.function`s inside of your shim-decorated methods if those `tf.functions` contain `get_variable` calls. Entering a `tf.function` resets `variable_scope`s, which means the TF1.x-style variable-scope-based variable reuse that the shim mimics will break down in this setting."
      ]
    },
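    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the safe direction, calling a shim-decorated layer from inside a `tf.function`, might look like the following. The `BiasLayer` class and the `forward` function are hypothetical names used only for illustration. The sketch assumes a TensorFlow installation that ships `tf.compat.v1.keras.utils.track_tf1_style_variables`.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "class BiasLayer(tf.keras.layers.Layer):\n",
        "  # Hypothetical layer that adds a `get_variable`-created bias.\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    with tf.compat.v1.variable_scope(\"bias_layer\"):\n",
        "      b = tf.compat.v1.get_variable(\n",
        "          name=\"b\",\n",
        "          shape=[inputs.shape[-1]],\n",
        "          initializer=tf.compat.v1.initializers.ones())\n",
        "    return inputs + b\n",
        "\n",
        "\n",
        "layer = BiasLayer()\n",
        "x = tf.zeros(shape=(2, 3))\n",
        "eager_out = layer(x)  # Runs eagerly and creates the variable.\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def forward(inputs):\n",
        "  # The decorated method is called from inside the `tf.function`;\n",
        "  # no `tf.function` is nested inside the decorated method itself.\n",
        "  return layer(inputs)\n",
        "\n",
        "\n",
        "graph_out = forward(x)  # Traced call that reuses the same variable.\n",
        "```\n",
        "\n",
        "Here the eager call and the traced call should produce the same result and leave the layer with a single tracked weight."
      ]
    },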
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aPytVgZWnShe"
      },
      "source": [
        "### Distribution strategies\n",
        "\n",
        "Calls to `get_variable` inside of `@track_tf1_style_variables`-decorated layer or module methods use standard `tf.Variable` variable creations under the hood. This means you can use them with the various distribution strategies available with `tf.distribute` such as `MirroredStrategy` and `TPUStrategy`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_DcK24FOA8A2"
      },
      "source": [
        "## Nesting `tf.Variable`s, `tf.Module`s, `tf.keras.layers` & `tf.keras.models` in decorated calls\n",
        "\n",
        "Decorating your layer call in `tf.compat.v1.keras.utils.track_tf1_style_variables` will only add automatic implicit tracking of variables created (and reused) via `tf.compat.v1.get_variable`. It will not capture weights directly created by `tf.Variable` calls, such as those used by typical Keras layers and most `tf.Module`s. This section describes how to handle these nested cases.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Azxza3bVOZlv"
      },
      "source": [
        "###  (Pre-existing usages) `tf.keras.layers` and `tf.keras.models`\n",
        "\n",
        "For pre-existing usages of nested Keras layers and models, use `tf.compat.v1.keras.utils.get_or_create_layer`. This is only recommended for easing migration of existing TF1.x nested Keras usages; new code should use explicit attribute setting as described below for tf.Variables and tf.Modules.\n",
        "\n",
        "To use `tf.compat.v1.keras.utils.get_or_create_layer`, wrap the code that constructs your nested model into a method, and pass it in to the method. Example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LN15TcRgHKsq"
      },
      "outputs": [],
      "source": [
        "class NestedModel(tf.keras.Model):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "\n",
        "  def build_model(self):\n",
        "    inp = tf.keras.Input(shape=(5, 5))\n",
        "    dense_layer = tf.keras.layers.Dense(\n",
        "        10, name=\"dense\", kernel_regularizer=\"l2\",\n",
        "        kernel_initializer=tf.compat.v1.ones_initializer())\n",
        "    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))\n",
        "    return model\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    # Get or create a nested model without assigning it as an explicit property\n",
        "    model = tf.compat.v1.keras.utils.get_or_create_layer(\n",
        "        \"dense_model\", self.build_model)\n",
        "    return model(inputs)\n",
        "\n",
        "layer = NestedModel(10)\n",
        "layer(tf.ones(shape=(5,5)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DgsKlltPHI8z"
      },
      "source": [
        "This method ensures that these nested layers are correctly reused and tracked by tensorflow. Note that the `@track_tf1_style_variables` decorator is still required on the appropriate method. The model builder method passed into `get_or_create_layer` (in this case, `self.build_model`), should take no arguments.\n",
        "\n",
        "Weights are tracked:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zO5A78MJsqO"
      },
      "outputs": [],
      "source": [
        "assert len(layer.weights) == 2\n",
        "weights = {x.name: x for x in layer.variables}\n",
        "\n",
        "assert set(weights.keys()) == {\"dense/bias:0\", \"dense/kernel:0\"}\n",
        "\n",
        "layer.weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o3Xsi-JbKTuj"
      },
      "source": [
        "And regularization loss as well:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mdK5RGm5KW5C"
      },
      "outputs": [],
      "source": [
        "tf.add_n(layer.losses)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J_VRycQYJrXu"
      },
      "source": [
        "###  Incremental migration: `tf.Variables` and `tf.Modules`\n",
        "\n",
        "If you need to embed `tf.Variable` calls or `tf.Module`s in your decorated methods (for example, if you are following the incremental migration to non-legacy TF2 APIs described later in this guide), you still need to explicitly track these, with the following requirements:\n",
        "* Explicitly make sure that the variable/module/layer is only created once\n",
        "* Explicitly attach them as instance attributes just as you would when defining a [typical module or layer](https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow)\n",
        "* Explicitly reuse the already-created object in follow-on calls\n",
        "\n",
        "This ensures that weights are not created new each call and are correctly reused. Additionally, this also ensures that existing weights and regularization losses get tracked.\n",
        "\n",
        "Here is an example of how this could look:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mrRPPoJ5ap5U"
      },
      "outputs": [],
      "source": [
        "class NestedLayer(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def __call__(self, inputs):\n",
        "    out = inputs\n",
        "    with tf.compat.v1.variable_scope(\"inner_dense\"):\n",
        "      # The weights are created with a `regularizer`,\n",
        "      # so the layer should track their regularization losses\n",
        "      kernel = tf.compat.v1.get_variable(\n",
        "          shape=[out.shape[-1], self.units],\n",
        "          regularizer=tf.keras.regularizers.L2(),\n",
        "          initializer=tf.compat.v1.initializers.glorot_normal,\n",
        "          name=\"kernel\")\n",
        "      bias = tf.compat.v1.get_variable(\n",
        "          shape=[self.units,],\n",
        "          initializer=tf.compat.v1.initializers.zeros,\n",
        "          name=\"bias\")\n",
        "      out = tf.linalg.matmul(out, kernel)\n",
        "      out = tf.compat.v1.nn.bias_add(out, bias)\n",
        "    return out\n",
        "\n",
        "class WrappedDenseLayer(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "    self.units = units\n",
        "    # Only create the nested tf.variable/module/layer/model\n",
        "    # once, and then reuse it each time!\n",
        "    self._dense_layer = NestedLayer(self.units)\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    with tf.compat.v1.variable_scope('outer'):\n",
        "      outputs = tf.compat.v1.layers.dense(inputs, 3)\n",
        "      outputs = tf.compat.v1.layers.dense(inputs, 4)\n",
        "      return self._dense_layer(outputs)\n",
        "\n",
        "layer = WrappedDenseLayer(10)\n",
        "\n",
        "layer(tf.ones(shape=(5, 5)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lo9h6wc6bmEF"
      },
      "source": [
        "Note that explicit tracking of the nested module is needed even though it is decorated with the `track_tf1_style_variables` decorator. This is because each module/layer with decorated methods has its own variable store associated with it. \n",
        "\n",
        "The weights are correctly tracked:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qt6USaTVbauM"
      },
      "outputs": [],
      "source": [
        "assert len(layer.weights) == 6\n",
        "weights = {x.name: x for x in layer.variables}\n",
        "\n",
        "assert set(weights.keys()) == {\"outer/inner_dense/bias:0\",\n",
        "                               \"outer/inner_dense/kernel:0\",\n",
        "                               \"outer/dense/bias:0\",\n",
        "                               \"outer/dense/kernel:0\",\n",
        "                               \"outer/dense_1/bias:0\",\n",
        "                               \"outer/dense_1/kernel:0\"}\n",
        "\n",
        "layer.trainable_weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dHn-bJoNJw7l"
      },
      "source": [
        "As well as regularization loss:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pq5GFtXjJyut"
      },
      "outputs": [],
      "source": [
        "layer.losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p7VKJj3JOCEk"
      },
      "source": [
        "Note that if the `NestedLayer` were a non-Keras `tf.Module` instead, variables would still be tracked but regularization losses would not be automatically tracked, so you would have to explicitly track them separately."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FsTgnydkdezQ"
      },
      "source": [
        "### Guidance on variable names\n",
        "\n",
        "Explicit `tf.Variable` calls and Keras layers use a different layer name / variable name autogeneration mechanism than you may be used to from the combination of `get_variable` and `variable_scopes`. Although the shim will make your variable names match for variables created by `get_variable` even when going from TF1.x graphs to TF2 eager execution & `tf.function`, it cannot guarantee the same for the variable names generated for `tf.Variable` calls and Keras layers that you embed within your method decorators. It is even possible for multiple variables to share the same name in TF2 eager execution and `tf.function`.\n",
        "\n",
        "You should take special care with this when following the sections on validating correctness and mapping TF1.x checkpoints later on in this guide."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CaP7fxoUWfMm"
      },
      "source": [
        "### Using `tf.compat.v1.make_template` in the decorated method\n",
        "\n",
        "**It is highly recommended you directly use `tf.compat.v1.keras.utils.track_tf1_style_variables` instead of using `tf.compat.v1.make_template`, as it is a thinner layer on top of TF2**. \n",
        "\n",
        "Follow the guidance in this section for prior TF1.x code that was already relying on `tf.compat.v1.make_template`.\n",
        "\n",
        "Because `tf.compat.v1.make_template` wraps code that uses `get_variable`, the `track_tf1_style_variables`  decorator allows you to use these templates in layer calls and successfully track the weights and regularization losses.\n",
        "\n",
        "However, do make sure to call `make_template` only once and then reuse the same template in each layer call. Otherwise, a new template will be created each time you call the layer along with a new set of variables.\n",
        "\n",
        "For example,"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iHEQN8z44dbK"
      },
      "outputs": [],
      "source": [
        "class CompatV1TemplateScaleByY(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "    def my_op(x, scalar_name):\n",
        "      var1 = tf.compat.v1.get_variable(scalar_name,\n",
        "                            shape=[],\n",
        "                            regularizer=tf.compat.v1.keras.regularizers.L2(),\n",
        "                            initializer=tf.compat.v1.constant_initializer(1.5))\n",
        "      return x * var1\n",
        "    self.scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    with tf.compat.v1.variable_scope('layer'):\n",
        "      # Using a scope ensures the `scale_by_y` name will not be incremented\n",
        "      # for each instantiation of the layer.\n",
        "      return self.scale_by_y(inputs)\n",
        "\n",
        "layer = CompatV1TemplateScaleByY()\n",
        "\n",
        "out = layer(tf.ones(shape=(2, 3)))\n",
        "print(\"weights:\", layer.weights)\n",
        "print(\"regularization loss:\", layer.losses)\n",
        "print(\"output:\", out)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3vKTJ7IsTEe8"
      },
      "source": [
        "Warning: Avoid sharing the same `make_template`-created template across multiple layer instances as it may break the variable and regularization loss tracking mechanisms of the shim decorator. Additionally, if you plan to use the same `make_template` name inside of multiple layer instances then you should nest the created template's usage inside of a `variable_scope`. If not, the generated name for the template's `variable_scope` will increment with each new instance of the layer. This could alter the weight names in unexpected ways."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P4E3-XPhWD2N"
      },
      "source": [
        "## Incremental migration to Native TF2\n",
        "\n",
        "As mentioned earlier, `track_tf1_style_variables` allows you to mix TF2-style object-oriented `tf.Variable`/`tf.keras.layers.Layer`/`tf.Module` usage with legacy `tf.compat.v1.get_variable`/`tf.compat.v1.layers`-style usage inside of the same decorated module/layer.\n",
        "\n",
        "This means that after you have made your TF1.x model fully-TF2-compatible, you can write all new model components with native (non-`tf.compat.v1`) TF2 APIs and have them interoperate with your older code.\n",
        "\n",
        "However, if you continue to modify your older model components, you may also choose to incrementally switch your legacy-style `tf.compat.v1` usage over to the purely-native object-oriented APIs that are recommended for newly written TF2 code.\n",
        "\n",
        "`tf.compat.v1.get_variable` usage can be replaced with either `self.add_weight` calls if you are decorating a Keras layer/model, or with `tf.Variable` calls if you are decorating Keras objects or `tf.Module`s.\n",
        "\n",
        "Both functional-style and object-oriented `tf.compat.v1.layers` can generally be replaced with the equivalent `tf.keras.layers` layer with no argument changes required.\n",
        "\n",
        "You may also consider chunks parts of your model or common patterns into individual layers/modules during your incremental move to purely-native APIs, which may themselves use `track_tf1_style_variables`.\n",
        "\n",
        "### A note on Slim and contrib.layers\n",
        "\n",
        "A large amount of older TF 1.x code uses the [Slim](https://ai.googleblog.com/2016/08/tf-slim-high-level-library-to-define.html) library, which was packaged with TF 1.x as `tf.contrib.layers`. Converting code using Slim to native TF 2 is more involved than converting `v1.layers`. In fact, it may make sense to convert your Slim code to `v1.layers` first, then convert to Keras. Below is some general guidance for converting Slim code.\n",
        "\n",
        "- Ensure all arguments are explicit. Remove `arg_scopes` if possible. If you still need to use them, split `normalizer_fn` and `activation_fn` into their own layers.\n",
        "- Separable conv layers map to one or more different Keras layers (depthwise, pointwise, and separable Keras layers).\n",
        "- Slim and `v1.layers` have different argument names and default values.\n",
        "- Note that some arguments have different scales."
      ]
    },
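    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As one possible sketch of splitting `normalizer_fn` and `activation_fn` into their own layers, a hypothetical Slim call such as `slim.conv2d(x, 32, 3, normalizer_fn=slim.batch_norm)` could be expressed with three explicit Keras layers. The filter count, kernel size, and input shape are assumptions made for illustration, and the remaining default argument values of the two libraries may differ and still need to be checked one by one:\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Assumes Slim's documented conv2d defaults: 'SAME' padding, ReLU activation,\n",
        "# and no bias when a normalizer_fn is supplied.\n",
        "conv = tf.keras.layers.Conv2D(32, 3, padding='same', use_bias=False)\n",
        "norm = tf.keras.layers.BatchNormalization()\n",
        "act = tf.keras.layers.ReLU()\n",
        "\n",
        "x = tf.ones(shape=(1, 8, 8, 3))\n",
        "# Normalization and activation are now separate, explicit steps.\n",
        "y = act(norm(conv(x), training=False))\n",
        "print(y.shape)\n",
        "```\n",
        "\n",
        "Keeping each step as its own layer makes every argument explicit, which is the point of the first bullet above."
      ]
    },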
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RFoULo-gazit"
      },
      "source": [
        "### Migration to Native TF2 ignoring checkpoint compatibility\n",
        "\n",
        "The following code sample demonstrates an incremental move of a model to purely-native APIs without considering checkpoint compatibility."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dPO9YJsb6r-D"
      },
      "outputs": [],
      "source": [
        "class CompatModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs, training=None):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          inputs, 3, 3,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      out = tf.compat.v1.layers.flatten(out)\n",
        "      out = tf.compat.v1.layers.dropout(out, training=training)\n",
        "      out = tf.compat.v1.layers.dense(\n",
        "          out, self.units,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      return out\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fp16xK6Oa8k9"
      },
      "source": [
        "Next, replace the `compat.v1` APIs with their native object-oriented equivalents in a piecewise manner. Start by switching the convolution layer to a Keras object created in the layer constructor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LOj1Swe16so3"
      },
      "outputs": [],
      "source": [
        "class PartiallyMigratedModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "    self.conv_layer = tf.keras.layers.Conv2D(\n",
        "      3, 3,\n",
        "      kernel_regularizer=\"l2\")\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs, training=None):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = self.conv_layer(inputs)\n",
        "      out = tf.compat.v1.layers.flatten(out)\n",
        "      out = tf.compat.v1.layers.dropout(out, training=training)\n",
        "      out = tf.compat.v1.layers.dense(\n",
        "          out, self.units,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      return out\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kzJF0H0sbce8"
      },
      "source": [
        "Use the [`v1.keras.utils.DeterministicRandomTestTool`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/keras/utils/DeterministicRandomTestTool) class to verify that this incremental change leaves the model with the same behavior as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MTJq0qW9_Tz2"
      },
      "outputs": [],
      "source": [
        "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n",
        "with random_tool.scope():\n",
        "  tf.keras.utils.set_random_seed(42)\n",
        "  layer = CompatModel(10)\n",
        "\n",
        "  inputs = tf.random.normal(shape=(10, 5, 5, 5))\n",
        "  original_output = layer(inputs)\n",
        "\n",
        "  # Grab the regularization loss as well\n",
        "  original_regularization_loss = tf.math.add_n(layer.losses)\n",
        "\n",
        "print(original_regularization_loss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X4Wq3wuaHjEV"
      },
      "outputs": [],
      "source": [
        "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n",
        "with random_tool.scope():\n",
        "  tf.keras.utils.set_random_seed(42)\n",
        "  layer = PartiallyMigratedModel(10)\n",
        "\n",
        "  inputs = tf.random.normal(shape=(10, 5, 5, 5))\n",
        "  migrated_output = layer(inputs)\n",
        "\n",
        "  # Grab the regularization loss as well\n",
        "  migrated_regularization_loss = tf.math.add_n(layer.losses)\n",
        "\n",
        "print(migrated_regularization_loss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mMMXS7EHjvCy"
      },
      "outputs": [],
      "source": [
        "# Verify that the regularization loss and output both match\n",
        "np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())\n",
        "np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RMxiMVFwbiQy"
      },
      "source": [
        "You have now replaced all of the individual `compat.v1.layers` with native Keras layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3dFCnyYc9DrX"
      },
      "outputs": [],
      "source": [
        "class NearlyFullyNativeModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "    self.conv_layer = tf.keras.layers.Conv2D(\n",
        "      3, 3,\n",
        "      kernel_regularizer=\"l2\")\n",
        "    self.flatten_layer = tf.keras.layers.Flatten()\n",
        "    self.dense_layer = tf.keras.layers.Dense(\n",
        "      self.units,\n",
        "      kernel_regularizer=\"l2\")\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = self.conv_layer(inputs)\n",
        "      out = self.flatten_layer(out)\n",
        "      out = self.dense_layer(out)\n",
        "      return out\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QGPqEjkGHgar"
      },
      "outputs": [],
      "source": [
        "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n",
        "with random_tool.scope():\n",
        "  tf.keras.utils.set_random_seed(42)\n",
        "  layer = NearlyFullyNativeModel(10)\n",
        "\n",
        "  inputs = tf.random.normal(shape=(10, 5, 5, 5))\n",
        "  migrated_output = layer(inputs)\n",
        "\n",
        "  # Grab the regularization loss as well\n",
        "  migrated_regularization_loss = tf.math.add_n(layer.losses)\n",
        "\n",
        "print(migrated_regularization_loss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uAs60eCdj6x_"
      },
      "outputs": [],
      "source": [
        "# Verify that the regularization loss and output both match\n",
        "np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())\n",
        "np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oA6viSo3bo3y"
      },
      "source": [
        "Finally, remove both any remaining (no-longer-needed) `variable_scope` usage and the `track_tf1_style_variables` decorator itself.\n",
        "\n",
        "You are now left with a version of the model that uses entirely native APIs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mIHpHWIRDunU"
      },
      "outputs": [],
      "source": [
        "class FullyNativeModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, units, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.units = units\n",
        "    self.conv_layer = tf.keras.layers.Conv2D(\n",
        "      3, 3,\n",
        "      kernel_regularizer=\"l2\")\n",
        "    self.flatten_layer = tf.keras.layers.Flatten()\n",
        "    self.dense_layer = tf.keras.layers.Dense(\n",
        "      self.units,\n",
        "      kernel_regularizer=\"l2\")\n",
        "\n",
        "  def call(self, inputs):\n",
        "    out = self.conv_layer(inputs)\n",
        "    out = self.flatten_layer(out)\n",
        "    out = self.dense_layer(out)\n",
        "    return out\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ttAmiCvLHW54"
      },
      "outputs": [],
      "source": [
        "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n",
        "with random_tool.scope():\n",
        "  tf.keras.utils.set_random_seed(42)\n",
        "  layer = FullyNativeModel(10)\n",
        "\n",
        "  inputs = tf.random.normal(shape=(10, 5, 5, 5))\n",
        "  migrated_output = layer(inputs)\n",
        "\n",
        "  # Grab the regularization loss as well\n",
        "  migrated_regularization_loss = tf.math.add_n(layer.losses)\n",
        "\n",
        "print(migrated_regularization_loss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ym5DYtT4j7e3"
      },
      "outputs": [],
      "source": [
        "# Verify that the regularization loss and output both match\n",
        "np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())\n",
        "np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oX4pdrzycIsa"
      },
      "source": [
        "### Maintaining checkpoint compatibility during migration to Native TF2\n",
        "\n",
        "The above migration process to native TF2 APIs changed both the variable names (as Keras APIs produce very different weight names), and the object-oriented paths that point to different weights in the model. The impact of these changes is that they will have broken both any existing TF1-style name-based checkpoints or TF2-style object-oriented checkpoints.\n",
        "\n",
        "However, in some cases, you might be able to take your original name-based checkpoint and find a mapping of the variables to their new names with approaches like the one detailed in the [Reusing TF1.x checkpoints guide](./migrating_checkpoints.ipynb).\n",
        "\n",
        "Some tips to making this feasible are as follows:\n",
        "- Variables still all have a `name` argument you can set.\n",
        "- Keras models also take a `name` argument as which they set as the prefix for their variables.\n",
        "- The `v1.name_scope` function can be used to set variable name prefixes. This is very different from `tf.variable_scope`. It only affects names, and doesn't track variables and reuse.\n",
        "\n",
        "With the above pointers in mind, the following code samples demonstrate a workflow you can adapt to your code to incrementally update part of a model while simultaneously updating checkpoints.\n",
        "\n",
        "Note: Due to the complexity of variable naming with Keras layers, this is not guaranteed to work for all use cases."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EFmMY3dcx3mR"
      },
      "source": [
        "1. Begin by switching functional-style `tf.compat.v1.layers` to their object-oriented versions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cRxCFmNjl2ta"
      },
      "outputs": [],
      "source": [
        "class FunctionalStyleCompatModel(tf.keras.layers.Layer):\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs, training=None):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          inputs, 3, 3,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          out, 4, 4,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          out, 5, 5,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      return out\n",
        "\n",
        "layer = FunctionalStyleCompatModel()\n",
        "layer(tf.ones(shape=(10, 10, 10, 10)))\n",
        "[v.name for v in layer.weights]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QvzUyXxjydAd"
      },
      "source": [
        "2. Next, assign the compat.v1.layer objects and any variables created by `compat.v1.get_variable` as properties of the `tf.keras.layers.Layer`/`tf.Module` object whose method is decorated with `track_tf1_style_variables` (note that any object-oriented TF2 style checkpoints will now save out both a path by variable name and the new object-oriented path)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "02jMQkJFmFwl"
      },
      "outputs": [],
      "source": [
        "class OOStyleCompatModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.conv_1 = tf.compat.v1.layers.Conv2D(\n",
        "          3, 3,\n",
        "          kernel_regularizer=\"l2\")\n",
        "    self.conv_2 = tf.compat.v1.layers.Conv2D(\n",
        "          4, 4,\n",
        "          kernel_regularizer=\"l2\")\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs, training=None):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = self.conv_1(inputs)\n",
        "      out = self.conv_2(out)\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          out, 5, 5,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      return out\n",
        "\n",
        "layer = OOStyleCompatModel()\n",
        "layer(tf.ones(shape=(10, 10, 10, 10)))\n",
        "[v.name for v in layer.weights]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8evFpd8Nq63v"
      },
      "source": [
        "3. Resave a loaded checkpoint at this point to save out paths both by the variable name (for compat.v1.layers), or by the object-oriented object graph."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7neFr-9pqmJX"
      },
      "outputs": [],
      "source": [
        "weights = {v.name: v for v in layer.weights}\n",
        "assert weights['model/conv2d/kernel:0'] is layer.conv_1.kernel\n",
        "assert weights['model/conv2d_1/bias:0'] is layer.conv_2.bias"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pvsi743Xh9wn"
      },
      "source": [
        "4. You can now swap out the object-oriented `compat.v1.layers` for native Keras layers while still being able to load the recently-saved checkpoint. Ensure that you preserve variable names for the remaining `compat.v1.layers` by still recording the auto-generated `variable_scopes` of the replaced layers. These switched layers/variables will now only use the object attribute path to the variables in the checkpoint instead of the variable name path.\n",
        "\n",
        "In general, you can replace usage of `compat.v1.get_variable` in variables attached to properties by:\n",
        "\n",
        "* Switching them to using `tf.Variable`, **OR**  \n",
        "* Updating them by using [`tf.keras.layers.Layer.add_weight`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_weight). Note that if you are not switching all layers in one go this may change auto-generated layer/variable naming for the remaining `compat.v1.layers` that are missing a `name` argument. If that is the case, you must keep the variable names for remaining `compat.v1.layers` the same by manually opening and closing a `variable_scope` corresponding to the removed `compat.v1.layer`'s generated scope name. Otherwise the paths from existing checkpoints may conflict and checkpoint loading will behave incorrectly.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NbixtIW-maoH"
      },
      "outputs": [],
      "source": [
        "def record_scope(scope_name):\n",
        "  \"\"\"Record a variable_scope to make sure future ones get incremented.\"\"\"\n",
        "  with tf.compat.v1.variable_scope(scope_name):\n",
        "    pass\n",
        "\n",
        "class PartiallyNativeKerasLayersModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.conv_1 = tf.keras.layers.Conv2D(\n",
        "          3, 3,\n",
        "          kernel_regularizer=\"l2\")\n",
        "    self.conv_2 = tf.keras.layers.Conv2D(\n",
        "          4, 4,\n",
        "          kernel_regularizer=\"l2\")\n",
        "\n",
        "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
        "  def call(self, inputs, training=None):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = self.conv_1(inputs)\n",
        "      record_scope('conv2d') # Only needed if follow-on compat.v1.layers do not pass a `name` arg\n",
        "      out = self.conv_2(out)\n",
        "      record_scope('conv2d_1') # Only needed if follow-on compat.v1.layers do not pass a `name` arg\n",
        "      out = tf.compat.v1.layers.conv2d(\n",
        "          out, 5, 5,\n",
        "          kernel_regularizer=\"l2\")\n",
        "      return out\n",
        "\n",
        "layer = PartiallyNativeKerasLayersModel()\n",
        "layer(tf.ones(shape=(10, 10, 10, 10)))\n",
        "[v.name for v in layer.weights]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2eaPpevGs3dA"
      },
      "source": [
        "Saving a checkpoint out at this step after constructing the variables will make it contain ***only*** the currently-available object paths. \n",
        "\n",
        "Ensure you record the scopes of the removed `compat.v1.layers` to preserve the auto-generated weight names for the remaining `compat.v1.layers`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EK7vtWBprObA"
      },
      "outputs": [],
      "source": [
        "weights = set(v.name for v in layer.weights)\n",
        "assert 'model/conv2d_2/kernel:0' in weights\n",
        "assert 'model/conv2d_2/bias:0' in weights"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DQ5-SfmWFTvY"
      },
      "source": [
        "5. Repeat the above steps until you have replaced all the `compat.v1.layers` and `compat.v1.get_variable`s in your model with fully-native equivalents."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PA1d2POtnTQa"
      },
      "outputs": [],
      "source": [
        "class FullyNativeKerasLayersModel(tf.keras.layers.Layer):\n",
        "\n",
        "  def __init__(self, *args, **kwargs):\n",
        "    super().__init__(*args, **kwargs)\n",
        "    self.conv_1 = tf.keras.layers.Conv2D(\n",
        "          3, 3,\n",
        "          kernel_regularizer=\"l2\")\n",
        "    self.conv_2 = tf.keras.layers.Conv2D(\n",
        "          4, 4,\n",
        "          kernel_regularizer=\"l2\")\n",
        "    self.conv_3 = tf.keras.layers.Conv2D(\n",
        "          5, 5,\n",
        "          kernel_regularizer=\"l2\")\n",
        "\n",
        "\n",
        "  def call(self, inputs, training=None):\n",
        "    with tf.compat.v1.variable_scope('model'):\n",
        "      out = self.conv_1(inputs)\n",
        "      out = self.conv_2(out)\n",
        "      out = self.conv_3(out)\n",
        "      return out\n",
        "\n",
        "layer = FullyNativeKerasLayersModel()\n",
        "layer(tf.ones(shape=(10, 10, 10, 10)))\n",
        "[v.name for v in layer.weights]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vZejG7rTsTb6"
      },
      "source": [
        "Remember to test to make sure the newly updated checkpoint still behaves as you expect. Apply the techniques described in the [validate numerical correctness guide](./validate_correctness.ipynb) at every incremental step of this process to ensure your migrated code runs correctly."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ewi_h-cs6n-I"
      },
      "source": [
        "## Handling TF1.x to TF2 behavior changes not covered by the modeling shims\n",
        "\n",
        "The modeling shims described in this guide can make sure that variables, layers, and regularization losses created with `get_variable`, `tf.compat.v1.layers`, and `variable_scope` semantics continue to work as before when using eager execution and `tf.function`, without having to rely on collections.\n",
        "\n",
        "This does not cover ***all*** TF1.x-specific semantics that your model forward passes may be relying on. In some cases, the shims might be insufficient to get your model forward pass running in TF2 on their own. Read the [TF1.x vs TF2 behaviors guide](./tf1_vs_tf2) to learn more about the behavioral differences between TF1.x and TF2."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "model_mapping.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
